#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include "cuda_utils.h"
#include "logging.h"
#include "utils.h"

#include "GLViewer.hpp"
#include "yolo.hpp"

#include <sl/Camera.hpp>
#include <NvInfer.h>

// Makes the nvinfer1 (TensorRT) names usable without qualification in this file.
using namespace nvinfer1;
// Presumably the non-maximum-suppression overlap threshold; it is not referenced
// in the visible part of this file — TODO confirm where it is consumed.
#define NMS_THRESH 0.4
// Confidence threshold handed to the detector's run() call in run_detector().
#define CONF_THRESH 0.3

/// Draws every object of `objs` (bounding box, optional mask and text label)
/// onto a copy of `image` and stores the blended result in `res`.
///
/// @param image  Source image. It is cloned before anything is written, so it
///               may refer to the same cv::Mat as `res`.
/// @param res    Output image receiving the annotated result.
/// @param objs   Objects to draw; only `objs.object_list` is read.
/// @param colors Colour palette indexed with `obj.id % colors.size()`; every
///               entry must provide at least three channel values. When the
///               palette is empty a fixed fallback colour is used instead.
static void draw_objects(cv::Mat const& image,
        cv::Mat &res,
        sl::Objects const& objs,
        std::vector<std::vector<int>> const& colors) {
    res = image.clone();
    cv::Mat mask{image.clone()};
    for (sl::ObjectData const& obj : objs.object_list) {
        // Guard the modulo below: an empty palette would divide by zero.
        cv::Scalar color = cv::Scalar(0, 255, 0);
        if (!colors.empty()) {
            size_t const idx_color{obj.id % colors.size()};
            color = cv::Scalar(colors[idx_color][0U], colors[idx_color][1U], colors[idx_color][2U]);
        }

        // Corner 0 supplies the origin, corner 1 the width and corner 2 the height.
        cv::Rect const rect{static_cast<int> (obj.bounding_box_2d[0U].x),
            static_cast<int> (obj.bounding_box_2d[0U].y),
            static_cast<int> (obj.bounding_box_2d[1U].x - obj.bounding_box_2d[0U].x),
            static_cast<int> (obj.bounding_box_2d[2U].y - obj.bounding_box_2d[0U].y)};
        cv::rectangle(res, rect, color, 2);

        // Bounded formatting: the write can never run past the end of `text`.
        char text[256U];
        std::snprintf(text, sizeof(text), "Class %d - %.1f%%", obj.raw_label, obj.confidence);

        // Paint the object's mask (when one is available) into the overlay image.
        if (obj.mask.isInit() && obj.mask.getWidth() > 0U && obj.mask.getHeight() > 0U) {
            const cv::Mat obj_mask = slMat2cvMat(obj.mask);
            mask(rect).setTo(color, obj_mask);
        }

        int baseLine{0};
        cv::Size const label_size{cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine)};

        // The label is anchored just below the top-left corner of the box.
        int const x{rect.x};
        int const y{std::min(rect.y + 1, res.rows)};

        cv::rectangle(res, cv::Rect(x, y, label_size.width, label_size.height + baseLine),{0, 0, 255}, -1);
        cv::putText(res, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, 0.4,{255, 255, 255}, 1);
    }
    // Blend the annotated image with the mask overlay.
    cv::addWeighted(res, 0.5, mask, 0.8, 1, res);
}

/// Writes one "[Sample]"-prefixed status line to std::cout.
///
/// @param msg_prefix Text printed right after the tag(s).
/// @param err_code   When different from SUCCESS, an "[Error]" tag is added and
///                   the code's short and verbose descriptions are appended.
/// @param msg_suffix Optional trailing text; skipped when empty.
void print(std::string msg_prefix, sl::ERROR_CODE err_code, std::string msg_suffix) {
    bool const failed{err_code != sl::ERROR_CODE::SUCCESS};

    std::cout << "[Sample] " << (failed ? "[Error] " : "") << msg_prefix << " ";

    if (failed)
        std::cout << " | " << toString(err_code) << " : " << toVerbose(err_code);

    if (!msg_suffix.empty())
        std::cout << " " << msg_suffix;

    std::cout << std::endl;
}

/// Builds an OpenCV rectangle from a detection box: origin at (x1, y1),
/// extent (x2 - x1, y2 - y1), each value rounded to the nearest integer.
cv::Rect get_rect(BBox box) {
    auto const left = round(box.x1);
    auto const top = round(box.y1);
    auto const width = round(box.x2 - box.x1);
    auto const height = round(box.y2 - box.y1);
    return cv::Rect(left, top, width, height);
}

std::vector<sl::uint2> cvt(const BBox &bbox_in) {
    std::vector<sl::uint2> bbox_out(4);
    bbox_out[0] = sl::uint2(bbox_in.x1, bbox_in.y1);
    bbox_out[1] = sl::uint2(bbox_in.x2, bbox_in.y1);
    bbox_out[2] = sl::uint2(bbox_in.x2, bbox_in.y2);
    bbox_out[3] = sl::uint2(bbox_in.x1, bbox_in.y2);
    return bbox_out;
}

// ---- State shared between main() and the detection thread (run_detector) ----

// detector_mtx is held while objects_in / custom_data_ts are exchanged.
// img_mtx is not used anywhere in the visible part of this file.
std::mutex detector_mtx, img_mtx;
// Stop request for run_detector(). Atomic because it is written by main() while
// the detection thread polls it; a plain bool here would be a data race.
std::atomic<bool> exit_detector{false};
// Latest detections, already converted to the format the ZED SDK ingests.
std::vector<sl::CustomBoxObjectData> objects_in;
// Left image: filled by main(), consumed by the detector.
sl::Mat left_sl;
// prev_ts: timestamp of the last frame the detector processed.
// custom_data_ts: timestamp of the frame objects_in was computed from.
sl::Timestamp prev_ts = 0, custom_data_ts = 0;
// Resolution passed to the detector; set in main() from the camera configuration.
sl::Resolution display_resolution;
// Inference engine wrapper, initialised in main().
Yolo detector;

void run_detector() {
    while (!exit_detector) {

        if (prev_ts != left_sl.timestamp) {

            // Running inference
            auto detections = detector.run(left_sl, display_resolution.height, display_resolution.width, CONF_THRESH);

            // Preparing for ZED SDK ingesting
            std::vector<sl::CustomBoxObjectData> objects_tmp;
            for (auto &it : detections) {
                sl::CustomBoxObjectData tmp;
                // Fill the detections into the correct format
                tmp.unique_object_id = sl::generate_unique_id();
                tmp.probability = it.prob;
                tmp.label = (int) it.label;
                tmp.bounding_box_2d = cvt(it.box);
                tmp.is_grounded = ((int) it.label == 0); // Only the first class (person) is grounded, that is moving on the floor plane
                // others are tracked in full 3D space
                objects_tmp.push_back(tmp);
            }

            detector_mtx.lock();
            objects_tmp.swap(objects_in);
            custom_data_ts = left_sl.timestamp;
            detector_mtx.unlock();

            prev_ts = left_sl.timestamp;
        }

        sl::sleep_ms(1);
    }
}

int main(int argc, char** argv) {
    if (argc == 1) {
        std::cout << "Usage: \n 1. ./yolo_onnx_zed -s yolov8s.onnx yolov8s.engine\n 2. ./yolo_onnx_zed -s yolov8s.onnx yolov8s.engine images:1x3x512x512\n 3. ./yolo_onnx_zed yolov8s.engine <SVO path>" << std::endl;
        return 0;
    }

    // Check Optim engine first
    if (std::string(argv[1]) == "-s" && (argc >= 4)) {
        std::string onnx_path = std::string(argv[2]);
        std::string engine_path = std::string(argv[3]);
        OptimDim dyn_dim_profile;

        if (argc == 5) {
            std::string optim_profile = std::string(argv[4]);
            bool error = dyn_dim_profile.setFromString(optim_profile);
            if (error) {
                std::cerr << "Invalid dynamic dimension argument, expecting something like 'images:1x3x512x512'" << std::endl;
                return EXIT_FAILURE;
            }
        }

        Yolo::build_engine(onnx_path, engine_path, dyn_dim_profile);
        return 0;
    }

    /// Opening the ZED camera before the model deserialization to avoid cuda context issue
    sl::Camera zed;
    sl::InitParameters init_parameters;
    init_parameters.sdk_verbose = true;
    init_parameters.depth_mode = sl::DEPTH_MODE::NEURAL;
    init_parameters.coordinate_system = sl::COORDINATE_SYSTEM::RIGHT_HANDED_Y_UP; // OpenGL's coordinate system is right_handed

    if (argc > 2) {
        std::string zed_opt = argv[2];
        if (zed_opt.find(".svo") != std::string::npos)
            init_parameters.input.setFromSVOFile(zed_opt.c_str());
    }

    // Open the camera
    auto returned_state = zed.open(init_parameters);
    if (returned_state != sl::ERROR_CODE::SUCCESS) {
        print("Camera Open", returned_state, "Exit program.");
        return EXIT_FAILURE;
    }
    zed.enablePositionalTracking();
    // Custom OD
    sl::ObjectDetectionParameters detection_parameters;
    detection_parameters.enable_tracking = true;
    detection_parameters.enable_segmentation = false; // designed to give person pixel mask with internal OD
    detection_parameters.detection_model = sl::OBJECT_DETECTION_MODEL::CUSTOM_BOX_OBJECTS;
    returned_state = zed.enableObjectDetection(detection_parameters);
    if (returned_state != sl::ERROR_CODE::SUCCESS) {
        print("enableObjectDetection", returned_state, "\nExit program.");
        zed.close();
        return EXIT_FAILURE;
    }
    auto camera_config = zed.getCameraInformation().camera_configuration;
    sl::Resolution pc_resolution(std::min((int) camera_config.resolution.width, 720), std::min((int) camera_config.resolution.height, 404));
    auto camera_info = zed.getCameraInformation(pc_resolution).camera_configuration;
    // Create OpenGL Viewer
    GLViewer viewer;
    viewer.init(argc, argv, camera_info.calibration_parameters.left_cam, true);
    // ---------


    // Creating the inference engine class
    std::string engine_name = "";
    if (argc > 0)
        engine_name = argv[1];
    else {
        std::cout << "Error: missing engine name as argument" << std::endl;
        return EXIT_FAILURE;
    }
    if (detector.init(engine_name)) {
        std::cerr << "Detector init failed!" << std::endl;
        return EXIT_FAILURE;
    }

    display_resolution = zed.getCameraInformation().camera_configuration.resolution;
    sl::Mat point_cloud;
    cv::Mat left_cv;
    sl::CustomObjectDetectionRuntimeParameters customObjectTracker_rt;
    sl::Objects objects;
    sl::Pose cam_w_pose;
    cam_w_pose.pose_data.setIdentity();
    auto zed_cuda_stream = zed.getCUDAStream();

    std::thread detection_thread(run_detector);
    
    while (viewer.isAvailable()) {
        if (zed.read() == sl::ERROR_CODE::SUCCESS) {
            // Get image for inference
            zed.retrieveImage(left_sl, sl::VIEW::LEFT, sl::MEM::GPU, sl::Resolution(0, 0), detector.stream);
            // Get the CPU image for display
            left_sl.updateCPUfromGPU(zed_cuda_stream);

            zed.grab();
            zed.retrieveMeasure(point_cloud, sl::MEASURE::XYZRGBA, sl::MEM::GPU, pc_resolution);
            zed.getPosition(cam_w_pose, sl::REFERENCE_FRAME::WORLD);

            // Get image for display
            left_cv = slMat2cvMat(left_sl);

            // wait for the detections
            while (left_sl.timestamp != custom_data_ts) sl::sleep_ms(1);

            detector_mtx.lock();
            // Send the custom detected boxes to the ZED
            zed.ingestCustomBoxObjects(objects_in);
            detector_mtx.unlock();

            // Retrieve the tracked objects, with 2D and 3D attributes
            zed.retrieveCustomObjects(objects, customObjectTracker_rt);

            // GL Viewer
            viewer.updateData(point_cloud, objects.object_list, cam_w_pose.pose_data);

            // Displaying the SDK objects
            draw_objects(left_cv, left_cv, objects, CLASS_COLORS);
            cv::imshow("ZED retrieved Objects", left_cv);
            int const key{cv::waitKey(1)};
            if (key == 'q' || key == 'Q' || key == 27)
                break;
        }
    }

    exit_detector = true;
    
    detection_thread.join();
    viewer.exit();


    return 0;
}
